{# Common parent stage: the CUDA image matching this release, cuDNN target (used as the tag flavor) and OS. IMAGE_NAME must be passed with --build-arg. #}
ARG IMAGE_NAME
FROM ${IMAGE_NAME}:{{ cuda.version.release_label }}-{{ cuda.cudnn.target }}-{{ cuda.os.distro }}{{ cuda.os.version }}{{ cuda.image_tag_suffix }} AS base

{% if "x86_64" in cuda %}
{# x86_64 cuDNN stage. It is named base-amd64 because Docker reports this platform as TARGETARCH=amd64, and the final stage selects base-${TARGETARCH}. #}
{% set arch_cudnn = cuda.x86_64.cudnn %}
FROM base AS base-amd64

ENV NV_CUDNN_VERSION={{ arch_cudnn.version }}
{# Package naming depends on the cuDNN major version: 8.x appends ".cuda<major>.<minor>" after the version, other majors put "-cuda-<major>-" before it. #}
{# The -devel package name is only exported for non-runtime targets and must live in NV_CUDNN_PACKAGE_DEV so it does not clobber the runtime package name. #}
{% if arch_cudnn.version[0] == "8" %}
ENV NV_CUDNN_PACKAGE=libcudnn{{ arch_cudnn.version[0] }}-${NV_CUDNN_VERSION}.cuda{{ cuda.version.major }}.{{ cuda.version.minor }}
{% if cuda.cudnn.target != "runtime" %}
ENV NV_CUDNN_PACKAGE_DEV=libcudnn{{ arch_cudnn.version[0] }}-devel-${NV_CUDNN_VERSION}.cuda{{ cuda.version.major }}.{{ cuda.version.minor }}
{% endif %}
{% else %}
ENV NV_CUDNN_PACKAGE=libcudnn{{ arch_cudnn.version[0] }}-cuda-{{ cuda.version.major }}-${NV_CUDNN_VERSION}
{% if cuda.cudnn.target != "runtime" %}
ENV NV_CUDNN_PACKAGE_DEV=libcudnn{{ arch_cudnn.version[0] }}-devel-cuda-{{ cuda.version.major }}-${NV_CUDNN_VERSION}
{% endif %}
{% endif %}

{# Optional direct RPM download of the runtime library; sha256sum is preferred over md5sum when the config provides it. #}
{% if "source" in arch_cudnn %}
{% set hashcmd = "sha256sum" if "sha256sum" in arch_cudnn else "md5sum" %}
ENV NV_CUDNN_DL_HASHCMD={{ hashcmd }}
ENV NV_CUDNN_DL_SUM={{ arch_cudnn[hashcmd] }}
ENV NV_CUDNN_DL_BASENAME={{ arch_cudnn.basename }}
ENV NV_CUDNN_DL_URL={{ arch_cudnn.source }}

# Download, checksum-verify and install the cuDNN runtime RPM.
# wget is installed and removed inside this single layer so it never persists in the image.
RUN yum -y install wget \
    && wget -q "${NV_CUDNN_DL_URL}" \
    && echo "${NV_CUDNN_DL_SUM}  ${NV_CUDNN_DL_BASENAME}" | ${NV_CUDNN_DL_HASHCMD} -c - \
    && rpm -i "${NV_CUDNN_DL_BASENAME}" \
    && rm -f "${NV_CUDNN_DL_BASENAME}" \
    && yum -y remove wget \
    && yum clean all \
    && rm -rf /var/cache/yum/*

{% endif %}
{# Optional direct RPM download of the development package (non-runtime targets only). Self-contained so it works whether or not the runtime RPM above was downloaded. #}
{% if cuda.cudnn.target != "runtime" and "source" in arch_cudnn.dev %}
{% set hashcmd = "sha256sum" if "sha256sum" in arch_cudnn.dev else "md5sum" %}
ENV NV_CUDNN_DEV_DL_HASHCMD={{ hashcmd }}
ENV NV_CUDNN_DEV_DL_SUM={{ arch_cudnn.dev[hashcmd] }}
ENV NV_CUDNN_DEV_DL_BASENAME={{ arch_cudnn.dev.basename }}
ENV NV_CUDNN_DEV_DL_URL={{ arch_cudnn.dev.source }}

# Download, checksum-verify and install the cuDNN development RPM in one layer.
RUN yum -y install wget \
    && wget -q "${NV_CUDNN_DEV_DL_URL}" \
    && echo "${NV_CUDNN_DEV_DL_SUM}  ${NV_CUDNN_DEV_DL_BASENAME}" | ${NV_CUDNN_DEV_DL_HASHCMD} -c - \
    && rpm -i "${NV_CUDNN_DEV_DL_BASENAME}" \
    && rm -f "${NV_CUDNN_DEV_DL_BASENAME}" \
    && yum -y remove wget \
    && yum clean all \
    && rm -rf /var/cache/yum/*

{% endif %}
{% endif %}
{% if "arm64" in cuda %}
{# arm64 cuDNN stage; the final stage selects it through base-${TARGETARCH} when TARGETARCH=arm64. #}
{% set arch_cudnn = cuda.arm64.cudnn %}
FROM base AS base-arm64

ENV NV_CUDNN_VERSION={{ arch_cudnn.version }}
{# Package naming depends on the cuDNN major version: 8.x appends ".cuda<major>.<minor>" after the version, other majors put "-cuda-<major>-" before it. #}
{# The -devel package name is only exported for non-runtime targets and must live in NV_CUDNN_PACKAGE_DEV so it does not clobber the runtime package name. #}
{% if arch_cudnn.version[0] == "8" %}
ENV NV_CUDNN_PACKAGE=libcudnn{{ arch_cudnn.version[0] }}-${NV_CUDNN_VERSION}.cuda{{ cuda.version.major }}.{{ cuda.version.minor }}
{% if cuda.cudnn.target != "runtime" %}
ENV NV_CUDNN_PACKAGE_DEV=libcudnn{{ arch_cudnn.version[0] }}-devel-${NV_CUDNN_VERSION}.cuda{{ cuda.version.major }}.{{ cuda.version.minor }}
{% endif %}
{% else %}
ENV NV_CUDNN_PACKAGE=libcudnn{{ arch_cudnn.version[0] }}-cuda-{{ cuda.version.major }}-${NV_CUDNN_VERSION}
{% if cuda.cudnn.target != "runtime" %}
ENV NV_CUDNN_PACKAGE_DEV=libcudnn{{ arch_cudnn.version[0] }}-devel-cuda-{{ cuda.version.major }}-${NV_CUDNN_VERSION}
{% endif %}
{% endif %}

{# Optional direct RPM download of the runtime library; sha256sum is preferred over md5sum when the config provides it. #}
{% if "source" in arch_cudnn %}
{% set hashcmd = "sha256sum" if "sha256sum" in arch_cudnn else "md5sum" %}
ENV NV_CUDNN_DL_HASHCMD={{ hashcmd }}
ENV NV_CUDNN_DL_SUM={{ arch_cudnn[hashcmd] }}
ENV NV_CUDNN_DL_BASENAME={{ arch_cudnn.basename }}
ENV NV_CUDNN_DL_URL={{ arch_cudnn.source }}

# Download, checksum-verify and install the cuDNN runtime RPM.
# wget is installed and removed inside this single layer so it never persists in the image.
RUN yum -y install wget \
    && wget -q "${NV_CUDNN_DL_URL}" \
    && echo "${NV_CUDNN_DL_SUM}  ${NV_CUDNN_DL_BASENAME}" | ${NV_CUDNN_DL_HASHCMD} -c - \
    && rpm -i "${NV_CUDNN_DL_BASENAME}" \
    && rm -f "${NV_CUDNN_DL_BASENAME}" \
    && yum -y remove wget \
    && yum clean all \
    && rm -rf /var/cache/yum/*

{% endif %}
{# Optional direct RPM download of the development package (non-runtime targets only). Self-contained so it works whether or not the runtime RPM above was downloaded. #}
{% if cuda.cudnn.target != "runtime" and "source" in arch_cudnn.dev %}
{% set hashcmd = "sha256sum" if "sha256sum" in arch_cudnn.dev else "md5sum" %}
ENV NV_CUDNN_DEV_DL_HASHCMD={{ hashcmd }}
ENV NV_CUDNN_DEV_DL_SUM={{ arch_cudnn.dev[hashcmd] }}
ENV NV_CUDNN_DEV_DL_BASENAME={{ arch_cudnn.dev.basename }}
ENV NV_CUDNN_DEV_DL_URL={{ arch_cudnn.dev.source }}

# Download, checksum-verify and install the cuDNN development RPM in one layer.
RUN yum -y install wget \
    && wget -q "${NV_CUDNN_DEV_DL_URL}" \
    && echo "${NV_CUDNN_DEV_DL_SUM}  ${NV_CUDNN_DEV_DL_BASENAME}" | ${NV_CUDNN_DEV_DL_HASHCMD} -c - \
    && rpm -i "${NV_CUDNN_DEV_DL_BASENAME}" \
    && rm -f "${NV_CUDNN_DEV_DL_BASENAME}" \
    && yum -y remove wget \
    && yum clean all \
    && rm -rf /var/cache/yum/*

{% endif %}
{% endif %}
{% if "ppc64le" in cuda %}
{# ppc64le cuDNN stage; the final stage selects it through base-${TARGETARCH} when TARGETARCH=ppc64le. #}
{% set arch_cudnn = cuda.ppc64le.cudnn %}
FROM base AS base-ppc64le

ENV NV_CUDNN_VERSION={{ arch_cudnn.version }}
{# Package naming depends on the cuDNN major version: 8.x appends ".cuda<major>.<minor>" after the version, other majors put "-cuda-<major>-" before it. #}
{# The -devel package name is only exported for non-runtime targets and must live in NV_CUDNN_PACKAGE_DEV so it does not clobber the runtime package name. #}
{% if arch_cudnn.version[0] == "8" %}
ENV NV_CUDNN_PACKAGE=libcudnn{{ arch_cudnn.version[0] }}-${NV_CUDNN_VERSION}.cuda{{ cuda.version.major }}.{{ cuda.version.minor }}
{% if cuda.cudnn.target != "runtime" %}
ENV NV_CUDNN_PACKAGE_DEV=libcudnn{{ arch_cudnn.version[0] }}-devel-${NV_CUDNN_VERSION}.cuda{{ cuda.version.major }}.{{ cuda.version.minor }}
{% endif %}
{% else %}
ENV NV_CUDNN_PACKAGE=libcudnn{{ arch_cudnn.version[0] }}-cuda-{{ cuda.version.major }}-${NV_CUDNN_VERSION}
{% if cuda.cudnn.target != "runtime" %}
ENV NV_CUDNN_PACKAGE_DEV=libcudnn{{ arch_cudnn.version[0] }}-devel-cuda-{{ cuda.version.major }}-${NV_CUDNN_VERSION}
{% endif %}
{% endif %}

{# Optional direct RPM download of the runtime library; sha256sum is preferred over md5sum when the config provides it. #}
{% if "source" in arch_cudnn %}
{% set hashcmd = "sha256sum" if "sha256sum" in arch_cudnn else "md5sum" %}
ENV NV_CUDNN_DL_HASHCMD={{ hashcmd }}
ENV NV_CUDNN_DL_SUM={{ arch_cudnn[hashcmd] }}
ENV NV_CUDNN_DL_BASENAME={{ arch_cudnn.basename }}
ENV NV_CUDNN_DL_URL={{ arch_cudnn.source }}

# Download, checksum-verify and install the cuDNN runtime RPM.
# wget is installed and removed inside this single layer so it never persists in the image.
RUN yum -y install wget \
    && wget -q "${NV_CUDNN_DL_URL}" \
    && echo "${NV_CUDNN_DL_SUM}  ${NV_CUDNN_DL_BASENAME}" | ${NV_CUDNN_DL_HASHCMD} -c - \
    && rpm -i "${NV_CUDNN_DL_BASENAME}" \
    && rm -f "${NV_CUDNN_DL_BASENAME}" \
    && yum -y remove wget \
    && yum clean all \
    && rm -rf /var/cache/yum/*

{% endif %}
{# Optional direct RPM download of the development package (non-runtime targets only). Self-contained so it works whether or not the runtime RPM above was downloaded. #}
{% if cuda.cudnn.target != "runtime" and "source" in arch_cudnn.dev %}
{% set hashcmd = "sha256sum" if "sha256sum" in arch_cudnn.dev else "md5sum" %}
ENV NV_CUDNN_DEV_DL_HASHCMD={{ hashcmd }}
ENV NV_CUDNN_DEV_DL_SUM={{ arch_cudnn.dev[hashcmd] }}
ENV NV_CUDNN_DEV_DL_BASENAME={{ arch_cudnn.dev.basename }}
ENV NV_CUDNN_DEV_DL_URL={{ arch_cudnn.dev.source }}

# Download, checksum-verify and install the cuDNN development RPM in one layer.
RUN yum -y install wget \
    && wget -q "${NV_CUDNN_DEV_DL_URL}" \
    && echo "${NV_CUDNN_DEV_DL_SUM}  ${NV_CUDNN_DEV_DL_BASENAME}" | ${NV_CUDNN_DEV_DL_HASHCMD} -c - \
    && rpm -i "${NV_CUDNN_DEV_DL_BASENAME}" \
    && rm -f "${NV_CUDNN_DEV_DL_BASENAME}" \
    && yum -y remove wget \
    && yum clean all \
    && rm -rf /var/cache/yum/*

{% endif %}
{% endif %}

{# Final image: continue from the per-architecture stage that matches the build platform. TARGETARCH is one of Docker's automatic platform build arguments. #}
FROM base-${TARGETARCH}

ARG TARGETARCH

LABEL maintainer="NVIDIA CORPORATION <sw-cuda-installer@nvidia.com>" \
      com.nvidia.cudnn.version="${NV_CUDNN_VERSION}"

# Install the cuDNN package(s) named by the selected architecture stage, then drop
# the yum caches in the same layer. The development package is requested for devel targets only.
RUN yum install -y \
    ${NV_CUDNN_PACKAGE} \
    {% if cuda.cudnn.target == "devel" %}
    ${NV_CUDNN_PACKAGE_DEV} \
    {% endif %}
    && yum clean all \
    && rm -rf /var/cache/yum/*
